import mindspore.nn as nn
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
from src.common import Conv2d, ConvBNReLU
from src.fluff import Fluff


#
# class EMA(nn.Cell):
#     """The Expectation-Maximization Attention Unit (EMAU).
#     Arguments:
#         in_channels (int): The input and output channel number.
#         reduce_ratio (int): Channel reduction ratio used to size the latent bases.
#         stage_num (int): The iteration number for EM.
#     """
#
#     def __init__(self, in_channels, reduce_ratio=2, stage_num=3, mu_momentum=0.9):
#         super(EMA, self).__init__()
#         self.stage_num = stage_num
#         k = in_channels // reduce_ratio
#         mu = np.random.randn(1, in_channels, k) * math.sqrt(2. / k)
#         mu = mu / (1e-6 + np.linalg.norm(mu, axis=1))
#         self.mu = Tensor(mu)
#         self.mu_momentum = mu_momentum
#
#         self.conv1 = Conv2d(in_channels, in_channels, 1)
#         self.conv2 = ConvBNReLU(in_channels, in_channels, 1, act='none')
#         self.bmm = P.BatchMatMul()
#         self.softmax = P.Softmax(axis=2)
#         self.relu = nn.ReLU()
#
#     def construct(self, x):
#         idn = x
#         # The first 1x1 conv
#         x = self.conv1(x)
#
#         # The EM Attention
#         b, c, h, w = x.size()
#         x = x.view(b, c, h * w)  # b * c * n
#         mu = self.mu.repeat(b, 1, 1)  # b * c * k
#
#         for i in range(self.stage_num):
#             x_t = x.permute(0, 2, 1)  # b * n * c
#             z = self.bmm(x_t, mu)  # b * n * k
#             z = self.softmax(z)  # b * n * k
#             z_ = z / (1e-6 + z.sum(dim=1, keepdim=True))
#             mu = self.bmm(x, z_)  # b * c * k
#             mu = self._l2norm(mu, dim=1)
#
#         # !!! The moving averaging operation is written in train.py, which is significant.
#
#         z_t = z.permute(0, 2, 1)  # b * k * n
#         x = mu.matmul(z_t)  # b * c * n
#         x = x.view(b, c, h, w)  # b * c * h * w
#         x = self.relu(x)
#
#         # The second 1x1 conv
#         x = self.conv2(x)
#         x = x + idn
#         x = self.relu(x)
#         if self.training:
#             with torch.no_grad():
#                 self.mu.mul_(self.mu_momentum).add_(mu.data.mean(dim=0, keepdim=True), alpha=1.0 - self.mu_momentum)
#         return x
#
#     def _l2norm(self, inp, dim):
#         """Normalize the inp tensor with l2-norm.
#         Returns a tensor where each sub-tensor of input along the given dim is
#         normalized such that the 2-norm of the sub-tensor is equal to 1.
#         Arguments:
#             inp (tensor): The input tensor.
#             dim (int): The dimension to slice over to get the sub-tensors.
#         Returns:
#             (tensor) The normalized tensor.
#         """
#         return inp / (1e-6 + np.linalg.norm(inp, axis=dim, keepdims=True))


class NonLocalAttention(nn.Cell):
    """Non-local self-attention block.

    Computes dot-product attention over all spatial positions of the input
    feature map and adds the aggregated result back as a residual.

    Args:
        in_channels (int): Input (and output) channel number.
        reduce_ratio (int): Channel reduction for the bottleneck. Default: 2.
    """

    def __init__(self, in_channels, reduce_ratio=2):
        super(NonLocalAttention, self).__init__()
        self.in_channels = in_channels
        # Bottleneck channel count; never below 1.
        self.inter_channels = max(in_channels // reduce_ratio, 1)

        # Value / query / key projections (1x1 convs).
        self.g = Conv2d(self.in_channels, self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.theta = Conv2d(self.in_channels, self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.phi = Conv2d(self.in_channels, self.inter_channels, kernel_size=1, stride=1, padding=0)
        # Output projection back to in_channels, followed by BN.
        self.W = nn.SequentialCell(
            Conv2d(self.inter_channels, self.in_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.in_channels),
        )

        self.matmul = P.BatchMatMul()
        self.div = P.Div()
        self.reshape = P.Reshape()
        self.add = P.TensorAdd()
        self.permute = P.Transpose()

    def construct(self, x):
        """Apply non-local attention to x [N, C, H, W]; returns same shape."""
        batch = x.shape[0]

        # Project, then flatten spatial dims: each tensor is (b, c', h*w).
        value = self.reshape(self.g(x), (batch, self.inter_channels, -1))
        query = self.reshape(self.theta(x), (batch, self.inter_channels, -1))
        key = self.reshape(self.phi(x), (batch, self.inter_channels, -1))

        # Put value / query into (b, n, c') layout for the matmuls below.
        value = self.permute(value, (0, 2, 1))
        query = self.permute(query, (0, 2, 1))

        # Pairwise affinities (b, n, n), normalized by the position count.
        affinity = self.matmul(query, key)
        affinity = self.div(affinity, affinity.shape[-1])

        # Aggregate values and restore the (b, c', h, w) layout.
        out = self.matmul(affinity, value)
        out = self.permute(out, (0, 2, 1))
        out = self.reshape(out, (batch, self.inter_channels, x.shape[2], x.shape[3]))

        # Residual connection.
        return self.add(self.W(out), x)


class DisentangledNonLocal2d(nn.Cell):
    """Disentangled Non-Local (DNL) block.

    Splits non-local attention into a whitened pairwise term and a unary
    term, then fuses both back into the input as a residual.

    Args:
        in_channels (int): Input (and output) channel number.
        reduce_ratio (int): Channel reduction for the bottleneck. Default: 2.
        sub_sample (bool): If True, 2x max-pool the key/value paths. Default: False.
        use_scale (bool): If True, scale pairwise logits by 1/sqrt(d). Default: True.
        temperature (float): Temperature to adjust attention. Default: 0.05
        mode (str): Pairwise function; one of 'gaussian', 'embedded_gaussian',
            'dot_product' or 'concatenation'. Default: 'embedded_gaussian'.
    """

    def __init__(self,
                 in_channels,
                 reduce_ratio=2,
                 sub_sample=False,
                 use_scale=True,
                 temperature=0.05,
                 mode='embedded_gaussian',
                 **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.reduction = reduce_ratio
        self.use_scale = use_scale
        # Fix: construct() reads self.sub_sample in 'gaussian' mode, but the
        # flag was never stored, which raised AttributeError.
        self.sub_sample = sub_sample
        self.inter_channels = max(in_channels // reduce_ratio, 1)
        self.mode = mode
        if mode not in ['gaussian', 'embedded_gaussian', 'dot_product', 'concatenation']:
            raise ValueError("Mode should be in 'gaussian', 'concatenation', "
                             f"'embedded_gaussian' or 'dot_product', but got "
                             f'{mode} instead.')

        # g, theta, phi are plain 1x1 convs (value / query / key projections).
        self.g = nn.Conv2d(self.in_channels, self.inter_channels, 1)
        self.conv_out = ConvBNReLU(self.inter_channels, self.in_channels, kernel_size=1, act='none')
        if self.mode != 'gaussian':
            self.theta = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size=1)
            self.phi = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size=1)

        if self.mode == 'concatenation':
            # NOTE(review): concat_project is created here, but construct() has
            # no concatenation branch for the pairwise term -- this mode is not
            # functional yet.
            self.concat_project = nn.SequentialCell(
                Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
                nn.ReLU())
        # self.init_weights(**kwargs)

        if sub_sample:
            # Fix: MindSpore's nn.MaxPool2d defaults to stride=1 (PyTorch
            # defaults stride to kernel_size), so the original pooled nothing;
            # make the 2x down-sampling explicit.
            max_pool_layer = nn.MaxPool2d(kernel_size=2, stride=2)
            self.g = nn.SequentialCell(self.g, max_pool_layer)
            if self.mode != 'gaussian':
                self.phi = nn.SequentialCell(self.phi, max_pool_layer)
            else:
                self.phi = max_pool_layer

        self.temperature = temperature
        # Produces the single-channel unary attention map.
        self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)

        self.matmul = P.BatchMatMul()
        self.div = P.Div()
        self.reshape = P.Reshape()
        self.add = P.TensorAdd()
        self.sub = P.Sub()
        self.permute = P.Transpose()
        self.pow = P.Pow()
        self.softmax = P.Softmax()
        self.mean = P.ReduceMean(keep_dims=True)

    def init_weights(self, std=0.01, zeros_init=True):
        # NOTE(review): this method is never called (the call in __init__ is
        # commented out) and uses torch-style `nn.init`, which MindSpore's `nn`
        # does not provide; port it to mindspore.common.initializer before
        # enabling. Left untouched as dead code.
        nn.init.normal_(self.g.weight, std=std)
        if self.g.bias is not None:
            nn.init.zeros_(self.g.bias)
        nn.init.normal_(self.conv_out.conv.weight, std=std)
        if self.conv_out.conv.bias is not None:
            nn.init.zeros_(self.conv_out.conv.bias)
        if zeros_init:
            nn.init.zeros_(self.conv_out.bn.weight)
            nn.init.zeros_(self.conv_out.bn.bias)
        else:
            nn.init.ones_(self.conv_out.bn.weight)
            nn.init.zeros_(self.conv_out.bn.bias)

    def construct(self, x):
        """Apply the DNL block to x [N, C, H, W]; returns the same shape."""
        n = x.shape[0]

        # g_x: [N, HxW', C'] -- value path (HxW' may be pooled).
        g_x = self.g(x)
        g_x = self.reshape(g_x, (n, self.inter_channels, -1))
        g_x = self.permute(g_x, (0, 2, 1))

        # theta_x: [N, HxW, C], phi_x: [N, C, HxW']
        if self.mode == 'gaussian':
            # Fix: 'gaussian' mode uses x itself (no theta/phi 1x1 conv), so
            # the channel count here is in_channels, not inter_channels.
            theta_x = self.reshape(x, (n, self.in_channels, -1))
            theta_x = self.permute(theta_x, (0, 2, 1))
            if self.sub_sample:
                phi_x = self.phi(x)
                phi_x = self.reshape(phi_x, (n, self.in_channels, -1))
            else:
                phi_x = self.reshape(x, (n, self.in_channels, -1))
        elif self.mode == 'concatenation':
            theta_x = self.theta(x)
            theta_x = self.reshape(theta_x, (n, self.inter_channels, -1, 1))
            phi_x = self.phi(x)
            phi_x = self.reshape(phi_x, (n, self.inter_channels, -1, 1))
        else:
            theta_x = self.theta(x)
            theta_x = self.reshape(theta_x, (n, self.inter_channels, -1))
            theta_x = self.permute(theta_x, (0, 2, 1))
            phi_x = self.phi(x)
            phi_x = self.reshape(phi_x, (n, self.inter_channels, -1))

        # Whitening: subtract the per-sample mean over the position axis
        # before the pairwise product (the "disentangled" part).
        theta_x = self.sub(theta_x, self.mean(theta_x, -2))
        phi_x = self.sub(phi_x, self.mean(phi_x, -1))

        # pairwise_weight: [N, HxW, HxW']
        pairwise_weight = self.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is the embedding dimension.
            pairwise_weight = self.div(pairwise_weight, theta_x.shape[-1] ** 0.5)
        pairwise_weight = self.div(pairwise_weight, self.temperature)
        pairwise_weight = self.softmax(pairwise_weight)

        # y: [N, HxW, C'] -> [N, C', H, W]
        y = self.matmul(pairwise_weight, g_x)
        y = self.permute(y, (0, 2, 1))
        y = self.reshape(y, (n, self.inter_channels, x.shape[2], x.shape[3]))

        # Unary term: one attention map shared by every output position.
        # NOTE(review): with sub_sample=True, unary_mask has HxW positions but
        # g_x has the pooled count -- these shapes only agree when
        # sub_sample=False; confirm before enabling sub_sample here.
        unary_mask = self.conv_mask(x)
        unary_mask = self.reshape(unary_mask, (n, 1, -1))
        unary_mask = self.softmax(unary_mask)
        # unary_x: [N, 1, C'] -> [N, C', 1, 1] (broadcast over H, W below)
        unary_x = self.matmul(unary_mask, g_x)
        unary_x = self.permute(unary_x, (0, 2, 1))
        unary_x = self.reshape(unary_x, (n, self.inter_channels, 1, 1))

        # Residual fusion of the pairwise and unary terms.
        output = self.add(x, self.conv_out(self.add(y, unary_x)))
        return output


class Attention(nn.Cell):
    """Attachment that runs an attention block on one feature-map level.

    Wraps optional pre/post conv stacks around a NonLocal or DNL block and
    applies them to the feature stored under the `down_stride` key of the
    inputs dict.

    Args:
        channels (dict): Mapping from down-stride to channel count.
        mid_channels (int): If > 0, reduce to this width before attention.
        out_channels (int): If > 0, project to this width after attention.
        num_conv_pre (int): Extra 3x3 ConvBNReLU layers before attention.
        num_conv_post (int): Extra 3x3 ConvBNReLU layers after attention.
        method (str): 'nl' or 'dnl'. Default: 'dnl'.
        down_stride (int): Which feature level to process. Default: 8.
        attention_cfg (dict): Extra kwargs for the attention block.
        add_fluff (bool): If True, prepend a Fluff block instead of a plain
            ConvBNReLU reduction.
        fluff_reduction_ratio (int): Reduction ratio passed to Fluff.
    """

    def __init__(self, channels: dict, mid_channels=-1, out_channels=-1, num_conv_pre=0, num_conv_post=0, method='dnl',
                 down_stride=8, attention_cfg=None, add_fluff=False, fluff_reduction_ratio=16):
        super().__init__()
        self._channels = channels.copy()
        self.down_stride = down_stride
        # Fix: precompute the dict key once; construct() previously hard-coded
        # "8" in both branches and only wrote the result back when
        # down_stride == 8, silently dropping the output otherwise.
        self._key = str(down_stride)
        in_channels = channels[down_stride]
        if attention_cfg is None:
            attention_cfg = {}

        self.pre_layers = []
        if add_fluff:
            if mid_channels <= 0:
                mid_channels = in_channels
            self.pre_layers.append(Fluff(in_channels, mid_channels, fluff_reduction_ratio))
            in_channels = mid_channels
        elif mid_channels > 0:
            self.pre_layers.append(ConvBNReLU(in_channels, mid_channels, 3, 1, 1))
            in_channels = mid_channels
        self.pre_layers.extend([ConvBNReLU(in_channels, in_channels, 3, 1, 1) for _ in range(num_conv_pre)])
        self.pre_layers = nn.SequentialCell(self.pre_layers)

        func = {
            'nl': NonLocalAttention,
            'dnl': DisentangledNonLocal2d,
            # 'ema': EMA,
        }[method]
        self.attention = func(in_channels=in_channels, **attention_cfg)

        self.post_layers = []
        self.post_layers.extend([ConvBNReLU(in_channels, in_channels, 3, 1, 1) for _ in range(num_conv_post)])
        if out_channels > 0:
            self.post_layers.append(ConvBNReLU(in_channels, out_channels, 3, 1, 1))
            in_channels = out_channels
        self.post_layers = nn.SequentialCell(self.post_layers)
        # Record the (possibly changed) output width for downstream consumers.
        self._channels[down_stride] = in_channels

    def construct(self, inputs: dict):
        """Process inputs[str(down_stride)] in place and return the dict."""
        x = inputs[self._key]
        x = self.pre_layers(x)
        x = self.attention(x)
        x = self.post_layers(x)
        inputs[self._key] = x
        return inputs

    @property
    def channels(self):
        # Per-stride channel counts after this attachment has been applied.
        return self._channels


def _test():
    """Smoke-test both attention blocks on random inputs (GPU target)."""
    import mindspore.context as context
    context.set_context(device_target="GPU")

    inp = Tensor(np.random.rand(2, 32, 64, 64).astype(np.float32))
    block = NonLocalAttention(32)
    out = block.compile_and_run(inp)
    print(out.shape)

    inp = Tensor(np.random.rand(2, 64, 32, 32).astype(np.float32))
    block = DisentangledNonLocal2d(64)
    out = block.compile_and_run(inp)
    out = block(inp)
    print(out.shape)


def _test2():
    """Cross-check the MindSpore Attention against the PyTorch reference.

    Builds both implementations with identical configs, copies the torch
    weights into the MindSpore net, and prints the max absolute difference
    of the outputs.
    """
    import torch
    import mindspore.context as context
    from detection.attachment.attention import Attention as Attention_torch
    from src import convert
    from mindspore.train.serialization import load_param_into_net
    context.set_context(device_target="GPU")

    # One shared config so both nets are built identically.
    cfg = dict(
        channels={8: 2048},
        down_stride=8,
        method="dnl",
        mid_channels=256,
        out_channels=256,
        num_conv_pre=1,
        num_conv_post=1,
        attention_cfg={'reduce_ratio': 2},
        add_fluff=True,
        fluff_reduction_ratio=4,
    )
    data = np.random.rand(2, 2048, 256 // 8, 256 // 8).astype(np.float32)

    ms_net = Attention(**cfg)
    ms_out = ms_net({"8": Tensor(data)})
    print(ms_out["8"].shape)
    # ms_net.compile_and_run({"8": Tensor(data)})

    torch_net = Attention_torch(**cfg)
    torch_net.eval()
    # Transfer the torch weights into the MindSpore net before comparing.
    load_param_into_net(ms_net, convert(torch_net.state_dict()))
    torch_out = torch_net({8: torch.from_numpy(data)})
    ms_out = ms_net({"8": Tensor(data)})
    print({k: v.shape for k, v in torch_out.items()})
    # Torch keys are ints, MindSpore keys are strings -- hence int(k).
    for k in ms_out.keys():
        print(k, np.abs(ms_out[k].asnumpy() - torch_out[int(k)].cpu().detach().numpy()).max())


# Run the MindSpore-vs-PyTorch parity check when executed as a script.
if __name__ == '__main__':
    _test2()
